// Copyright 2007 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: Russell L. Smith

// Fast matrix multiply and solve functions.
//
// These functions are competitive with the ones generated by ATLAS, and they
// are much smaller. This is possible because the typical matrix size
// passed to these functions is 12x12 (ATLAS is optimized for somewhat larger
// matrices).

#include <vector>

#include "gpo/common.h"
#include "gpo/fast_block.h"

namespace GPO {

// Each of solve, multiply2 and multiply2_lt has two implementations in this
// file:
//   * a hand-optimized one, defined below under the OPT_* macro name, and
//   * a straightforward reference one, defined under the REF_* macro name.
// The reference code exists for testing the optimized code and for
// performance comparison.
//
// Normally the optimized code is exported under the public name and the
// reference code under the matching *_ref name. Defining
// USE_REFERENCE_FUNCTIONS swaps the two, so callers of the public names get
// the reference code instead.

//#define USE_REFERENCE_FUNCTIONS

#if !defined(USE_REFERENCE_FUNCTIONS)
  // Default: optimized code under the public names.
  #define OPT_SOLVE solve
  #define REF_SOLVE solve_ref
  #define OPT_MULTIPLY2 multiply2
  #define REF_MULTIPLY2 multiply2_ref
  #define OPT_MULTIPLY2_LT multiply2_lt
  #define REF_MULTIPLY2_LT multiply2_lt_ref
#else
  // Reference code under the public names.
  #define OPT_SOLVE solve_ref
  #define REF_SOLVE solve
  #define OPT_MULTIPLY2 multiply2_ref
  #define REF_MULTIPLY2 multiply2
  #define OPT_MULTIPLY2_LT multiply2_lt_ref
  #define REF_MULTIPLY2_LT multiply2_lt
#endif

//***************************************************************************
// Optimized functions

// Compute the Cholesky factorization A = L*L' in place.
//
// A is an n x n symmetric matrix stored as a packed lower triangle, row by
// row: element (i,j) with j <= i lives at A[i*(i+1)/2 + j]. On return the
// same storage holds the lower-triangular factor L in the same packed layout.
//
// A should be positive definite. If it is not, some pivot becomes
// non-positive:
//   * sqrt() of a negative value yields NaN. It raises the FE_INVALID flag,
//     which only traps if floating-point exceptions are enabled.
//   * A zero pivot yields an infinite reciprocal.
// In both cases the bad value will typically contaminate the entries
// computed after it.
void factor(double *A, int n) {
  if (n <= 0) return;

  // recip[j] caches 1/L(j,j), so each off-diagonal entry costs a multiply
  // rather than a divide. Typical matrices are small (about 12x12), so use a
  // stack buffer and fall back to the heap only for large n. This replaces a
  // variable-length array, which is a compiler extension, not ISO C++.
  const int kStackDim = 64;
  double recip_stack[kStackDim];
  std::vector<double> recip_heap;
  double *recip = recip_stack;
  if (n > kStackDim) {
    recip_heap.resize(n);
    recip = &recip_heap[0];
  }

  double *aa = A;                    // start of row i
  for (int i = 0; i < n; i++) {
    double *bb = A;                  // start of row j
    double *cc = A + (i*(i+1))/2;    // element (i,j), advancing along row i
    for (int j = 0; j < i; j++) {
      // L(i,j) = (A(i,j) - sum_{k<j} L(i,k)*L(j,k)) / L(j,j)
      double sum = *cc;
      double *a = aa;
      double *b = bb;
      for (int k = j; k; k--) sum -= (*(a++))*(*(b++));
      *cc = sum * recip[j];
      bb += j+1;
      cc++;
    }
    // cc now points at the diagonal (i,i):
    // L(i,i) = sqrt(A(i,i) - sum_{k<i} L(i,k)^2)
    double sum = *cc;
    double *a = aa;
    for (int k = i; k; k--, a++) sum -= (*a)*(*a);
    *cc = sqrt(sum);   // NaN (and FE_INVALID) if sum < 0: A not pos. def.
    recip[i] = 1.0/(*cc);
    aa += i+1;
  }
}


/*
 * This variant of solve() handles all right-hand sides simultaneously.
 * The optimization seems to make very little difference to speed, so it
 * is kept here only as commented-out code.

void OPT_SOLVE (const double *L, double *B, int n, int m) {
  CHECK (L != B);
  double sums[m],ell,*b,*s;
  int i, j, k, n2 = 2*n, n3 = 3*n, n4 = 4*n;
  for (i=0; i<n; i++) {
    memset (sums,0,m*sizeof(double));
    for (j=0; j<i; j++) {
      ell = L[0];
      b = B + j;
      s = sums;
      for (k = m; k > 3; k -= 4) {
        s[0] += ell*b[0];
        s[1] += ell*b[n];
        s[2] += ell*b[n2];
        s[3] += ell*b[n3];
        s += 4;
        b += n4;
      }
      for (; k > 0; k--) {
        (*s++) += ell*(*b);
        b += n;
      }
      L++;
    }
    ell = 1.0 / L[0];
    b = B + i;
    s = sums;
    for (k = m; k > 3; k -= 4) {
      b[0] = (b[0] - s[0]) * ell;
      b[n] = (b[n] - s[1]) * ell;
      b[n2] = (b[n2] - s[2]) * ell;
      b[n3] = (b[n3] - s[3]) * ell;
      s += 4;
      b += n4;
    }
    for (; k > 0; k--) {
      b[0] = (b[0] - (*s++)) * ell;
      b += n;
    }
    L++;
  }
}

*/


// Solve L*X = B in place for m right-hand sides (forward substitution).
//
// Layout:
//   * L is an n x n lower-triangular matrix stored as a packed lower
//     triangle, row by row: L(i,j) with j <= i lives at L[i*(i+1)/2 + j].
//   * B holds the m right-hand sides one after another, each as n
//     consecutive doubles (right-hand side r occupies B[r*n .. r*n + n-1]).
// On return each right-hand side is overwritten with its solution vector.
// The diagonal of L must be nonzero.
//
// Right-hand sides are solved one at a time. Within each solve, rows are
// processed three at a time so that every load of x(j) is shared by three
// rows of L. The dot-product loops are unrolled by two. Reciprocals of the
// diagonal of L are precomputed so the inner work is multiply-only.
void OPT_SOLVE(const double *L, double *B, int n, int m) {
  if (n <= 0 || m <= 0) return;

  // Reciprocals of the diagonal of L. Typical matrices are small (about
  // 12x12), so use a stack buffer and fall back to the heap only for large n.
  // This replaces a variable-length array, which is a compiler extension,
  // not ISO C++.
  const int kStackDim = 64;
  double recip_stack[kStackDim];
  std::vector<double> recip_heap;
  double *recip = recip_stack;
  if (n > kStackDim) {
    recip_heap.resize(n);
    recip = &recip_heap[0];
  }

  const double *ell = L;
  for (int i = 0; i < n; i++) {
    recip[i] = 1.0 / (*ell);
    ell += i+2;  // L(i,i) -> L(i+1,i+1)
  }

  for (int rhs = 0; rhs < m; rhs++) {
    int i;
    // Rows i, i+1, i+2 together. ell walks along row i. Row i+1 starts i+1
    // entries after row i, and row i+2 starts 2i+3 entries after it.
    for (i = 0; i <= n-3; i += 3) {
      int i2 = i*2;
      double sum1 = 0;
      double sum2 = 0;
      double sum3 = 0;
      ell = L + (i*(i+1))/2;
      double *ex = B;
      int j;
      // Columns 0..i-1, two at a time.
      for (j = i-2; j >= 0; j -= 2) {
        double p1 = ell[0];
        double q1 = ex[0];
        double p2 = ell[i+1];
        double p3 = ell[i2+3];
        sum1 += p1 * q1;
        sum2 += p2 * q1;
        sum3 += p3 * q1;
        p1 = ell[1];
        q1 = ex[1];
        p2 = ell[i+2];
        p3 = ell[i2+4];
        sum1 += p1 * q1;
        sum2 += p2 * q1;
        sum3 += p3 * q1;
        ell += 2;
        ex += 2;
      }
      // At most one leftover column (when i is odd).
      j += 2;
      for (; j > 0; j--) {
        double p1 = ell[0];
        double q1 = ex[0];
        double p2 = ell[i+1];
        double p3 = ell[i2+3];
        sum1 += p1 * q1;
        sum2 += p2 * q1;
        sum3 += p3 * q1;
        ell += 1;
        ex += 1;
      }
      // ell now points at L(i,i) and ex at x(i). Finish the 3x3 diagonal
      // block by forward substitution.
      sum1 = (ex[0] - sum1) * recip[i];
      ex[0] = sum1;
      double p1 = ell[i+1];                     // L(i+1,i)
      sum2 = (ex[1] - sum2 - p1*sum1) * recip[i+1];
      ex[1] = sum2;
      p1 = ell[i2+3];                           // L(i+2,i)
      double p2 = ell[i2+4];                    // L(i+2,i+1)
      sum3 = (ex[2] - sum3 - p1*sum1 - p2*sum2) * recip[i+2];
      ex[2] = sum3;
    }
    // Remaining 0-2 rows, one at a time.
    for (; i < n; i++) {
      double sum1 = 0;
      ell = L + (i*(i+1))/2;
      double *ex = B;
      int j;
      for (j = i-2; j >= 0; j -= 2) {
        double p1 = ell[0];
        double q1 = ex[0];
        sum1 += p1 * q1;
        p1 = ell[1];
        q1 = ex[1];
        sum1 += p1 * q1;
        ell += 2;
        ex += 2;
      }
      j += 2;
      for (; j > 0; j--) {
        double p1 = ell[0];
        double q1 = ex[0];
        sum1 += p1 * q1;
        ell += 1;
        ex += 1;
      }
      sum1 = (ex[0] - sum1) * recip[i];
      ex[0] = sum1;
    }
    B += n;  // next right-hand side
  }
}


// Solve L'*x = b in place, where L' is the transpose of L (back substitution).
//
// L is an n x n lower-triangular matrix stored as a packed lower triangle,
// row by row: L(i,j) with j <= i lives at L[i*(i+1)/2 + j]. B holds b on
// entry and x on return. Components are computed from the last to the first.
// Each one walks up column i of L, which is row i of L'.
void solveT(const double *L, double *B, int n) {
  CHECK(L != B);
  // Nothing to solve. Returning here also avoids forming L - 1 below, which
  // would point before the start of L when n == 0 (undefined behavior).
  if (n <= 0) return;
  // Point at L(n-1,n-1). At the start of iteration i, L points at L(n-1,i).
  L += (n*(n+1))/2-1;
  for (int i = n-1; i >= 0; i--) {
    double sum = 0;
    const double *ell = L;
    for (int k = n-1; k > i; k--) {
      sum += (*ell)*B[k];   // L(k,i) * x(k)
      ell -= k;             // row k starts k entries after row k-1
    }
    B[i] = (B[i]-sum)/(*ell);   // ell now points at L(i,i)
    L--;
  }
}


// A += scale * B' * C.
//
// All matrices are dense and stored row by row:
//   * A is p x r: A(i,j) is at A[i*r + j].
//   * B is q x p: B(k,i) is at B[k*p + i].
//   * C is q x r: C(k,j) is at C[k*r + j].
// Each A(i,j) is incremented by scale * sum_k B(k,i) * C(k,j). If any
// dimension is zero this is a no-op.
void multiply1(double *A, double scale, const double *B, const double *C,
               int p, int q, int r) {
  if (p == 0 || q == 0 || r == 0) return;
  CHECK(A != B && A != C);
  CHECK(A && B && C && p > 0 && q > 0 && r > 0);
  double *out = A;  // walks A in row-major order
  for (int row = 0; row < p; ++row) {
    for (int col = 0; col < r; ++col, ++out) {
      // Column `row` of B and column `col` of C, stepping down one row each.
      const double *bcol = B + row;
      const double *ccol = C + col;
      double dot = 0;
      for (int k = q; k > 0; --k, bcol += p, ccol += r) {
        dot += (*bcol) * (*ccol);
      }
      *out += scale * dot;
    }
  }
}


// A += scale * B * C'.
//
// All matrices are dense and stored row by row:
//   * A is p x r: A(i,j) is at A[i*r + j].
//   * B is p x q: row i starts at B + i*q.
//   * C is r x q: row j starts at C + j*q.
// Each A(i,j) is incremented by scale * dot(row i of B, row j of C). If any
// dimension is zero this is a no-op. A must not be the same array as B or C
// (CHECKed for pointer equality only).
void OPT_MULTIPLY2(double *A, double scale, const double *B, const double *C,
                   int p, int q, int r) {
  // Optimizations:
  //   * Use a 2x2 outer product to halve the number of loads
  //   * Unroll inner loops by 4.
  //   * Simplify index computations in inner loops.
  if (p == 0 || q == 0 || r == 0) return;
  CHECK(A != B && A != C);
  // p2 and r2 are p and r rounded down to even: the extent covered by 2x2
  // blocks. An odd last row and/or column of A is handled separately below.
  int p2 = p&~1, r2 = r&~1;
  int i;
  for (i = 0; i < p2; i += 2) {
    int j;
    // 2x2 block: rows i,i+1 of B against rows j,j+1 of C. b[q] and c[q]
    // address the second row of each pair. sumXY accumulates
    // dot(B row i+X-1, C row j+Y-1).
    for (j = 0; j < r2; j += 2) {
      const double *b = B+i*q;
      const double *c = C+j*q;
      double sum11 = 0;
      double sum12 = 0;
      double sum21 = 0;
      double sum22 = 0;
      int k;
      // Four elements per row per iteration. Loads are interleaved with the
      // multiply-adds (presumably to help instruction scheduling); each
      // accumulator still adds its products in increasing element order.
      for (k = q; k > 3; k -= 4) {
        double loadb0 = b[0];
        double loadc0 = c[0];
        sum11 += loadb0 * loadc0;
        double loadb1 = b[q];
        sum21 += loadb1 * loadc0;
        double loadc1 = c[q];
        sum12 += loadb0 * loadc1;
        loadb0 = b[1];
        sum22 += loadb1 * loadc1;
        loadc0 = c[1];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+1];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+1];
        sum12 += loadb0 * loadc1;
        loadb0 = b[2];
        sum22 += loadb1 * loadc1;
        loadc0 = c[2];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+2];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+2];
        sum12 += loadb0 * loadc1;
        loadb0 = b[3];
        sum22 += loadb1 * loadc1;
        loadc0 = c[3];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+3];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+3];
        c += 4;
        b += 4;
        sum12 += loadb0 * loadc1;
        sum22 += loadb1 * loadc1;
      }
      // Remaining 0-3 elements of each row.
      for (; k > 0; k--) {
        double loadb0 = b[0];
        double loadc0 = c[0];
        sum11 += loadb0 * loadc0;
        double loadb1 = b[q];
        sum21 += loadb1 * loadc0;
        double loadc1 = c[q];
        c++;
        b++;
        sum12 += loadb0 * loadc1;
        sum22 += loadb1 * loadc1;
      }
      // a[0], a[1] are A(i,j), A(i,j+1); a[r], a[r+1] are row i+1.
      double *a = A + i*r + j;
      a[0] += scale * sum11;
      a[1] += scale * sum12;
      a[r] += scale * sum21;
      a[r+1] += scale * sum22;
    }
    // Odd r: last column j = r-1, i.e. rows i,i+1 of B against row j of C.
    if (j < r) {
      const double *b = B+i*q;
      const double *c = C+j*q;
      double sum11 = 0;
      double sum21 = 0;
      int k;
      for (k = q; k > 3; k -= 4) {
        double loadc0 = c[0];
        sum11 += b[0] * loadc0;
        sum21 += b[q] * loadc0;
        loadc0 = c[1];
        sum11 += b[1] * loadc0;
        sum21 += b[q+1] * loadc0;
        loadc0 = c[2];
        sum11 += b[2] * loadc0;
        sum21 += b[q+2] * loadc0;
        loadc0 = c[3];
        c += 4;
        sum11 += b[3] * loadc0;
        sum21 += b[q+3] * loadc0;
        b += 4;
      }
      for (; k > 0; k--) {
        double loadc0 = c[0];
        c++;
        sum11 += b[0] * loadc0;
        sum21 += b[q] * loadc0;
        b++;
      }
      double *a = A + i*r + j;
      a[0] += scale * sum11;
      a[r] += scale * sum21;
    }
  }
  // Odd p: last row i = p-1 of B.
  if (i < p) {
    int j;
    // Row i of B against rows j,j+1 of C.
    for (j = 0; j < r2; j += 2) {
      const double *b = B+i*q;
      const double *c = C+j*q;
      double sum11 = 0;
      double sum12 = 0;
      int k;
      for (k = q; k > 3; k -= 4) {
        double loadb0 = b[0];
        sum11 += loadb0 * c[0];
        sum12 += loadb0 * c[q+0];
        loadb0 = b[1];
        sum11 += loadb0 * c[1];
        sum12 += loadb0 * c[q+1];
        loadb0 = b[2];
        sum11 += loadb0 * c[2];
        sum12 += loadb0 * c[q+2];
        loadb0 = b[3];
        b += 4;
        sum11 += loadb0 * c[3];
        sum12 += loadb0 * c[q+3];
        c += 4;
      }
      for (; k > 0; k--) {
        double loadb0 = b[0];
        b++;
        sum11 += loadb0 * c[0];
        sum12 += loadb0 * c[q+0];
        c++;
      }
      double *a = A + i*r + j;
      a[0] += scale * sum11;
      a[1] += scale * sum12;
    }
    // Odd p and odd r: the single corner element A(p-1, r-1).
    if (j < r) {
      const double *b = B+i*q;
      const double *c = C+j*q;
      double sum11 = 0;
      int k;
      for (k = q; k > 3; k -= 4) {
        sum11 += b[0] * c[0];
        sum11 += b[1] * c[1];
        sum11 += b[2] * c[2];
        sum11 += b[3] * c[3];
        b += 4;
        c += 4;
      }
      for (; k > 0; k--) {
        sum11 += b[0] * c[0];
        b++;
        c++;
      }
      A[i*r+j] += scale * sum11;
    }
  }
}


// A += scale * (lower triangle of B * C').
//
// Layout:
//   * B and C are both p x q, dense and stored row by row (row i starts at
//     B + i*q, resp. C + i*q).
//   * A is p x p, stored as a packed lower triangle row by row: A(i,j) with
//     j <= i is at A[i*(i+1)/2 + j].
// Only entries with j <= i are computed: A(i,j) += scale * dot(row i of B,
// row j of C). If p or q is zero this is a no-op. A must not be the same
// array as B or C (CHECKed for pointer equality only).
//
// The parameter A is used as a cursor that advances through the packed
// storage as each block is finished.
void OPT_MULTIPLY2_LT(double *A, double scale, const double *B, const double *C,
                      int p, int q) {
  // Optimizations:
  //   * Use a 2x2 outer product to halve the number of loads
  //   * Unroll inner loops by 4.
  //   * Simplify index computations in inner loops.
  if (p == 0 || q == 0) return;
  CHECK(A != B && A != C);
  // p rounded down to even: the rows covered by pairs. An odd last row is
  // handled separately below.
  int p2 = p&~1;
  int i;
  for (i = 0; i < p2; i += 2) {
    int j;
    // Off-diagonal 2x2 blocks: rows i,i+1 of B against rows j,j+1 of C, for
    // j < i (i is even, so columns 0..i-1 pair up exactly). A points at
    // A(i,j). Row i holds i+1 entries, so A(i+1,j) is at A[i+1].
    for (j = 0; j < i; j += 2) {
      double sum11 = 0;
      double sum12 = 0;
      double sum21 = 0;
      double sum22 = 0;
      const double *b = B + i*q;
      const double *c = C + j*q;
      int k;
      // Four elements per row per iteration, loads interleaved with the
      // multiply-adds (presumably for instruction scheduling).
      for (k = q; k > 3; k -= 4) {
        double loadb0 = b[0];
        double loadc0 = c[0];
        sum11 += loadb0 * loadc0;
        double loadb1 = b[q];
        sum21 += loadb1 * loadc0;
        double loadc1 = c[q];
        sum12 += loadb0 * loadc1;
        loadb0 = b[1];
        sum22 += loadb1 * loadc1;
        loadc0 = c[1];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+1];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+1];
        sum12 += loadb0 * loadc1;
        loadb0 = b[2];
        sum22 += loadb1 * loadc1;
        loadc0 = c[2];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+2];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+2];
        sum12 += loadb0 * loadc1;
        loadb0 = b[3];
        sum22 += loadb1 * loadc1;
        loadc0 = c[3];
        sum11 += loadb0 * loadc0;
        loadb1 = b[q+3];
        sum21 += loadb1 * loadc0;
        loadc1 = c[q+3];
        b += 4;
        c += 4;
        sum12 += loadb0 * loadc1;
        sum22 += loadb1 * loadc1;
      }
      // Remaining 0-3 elements of each row.
      for (; k > 0; k--) {
        double loadb0 = b[0];
        double loadc0 = c[0];
        sum11 += loadb0 * loadc0;
        double loadb1 = b[q];
        sum21 += loadb1 * loadc0;
        double loadc1 = c[q];
        sum12 += loadb0 * loadc1;
        sum22 += loadb1 * loadc1;
        b++;
        c++;
      }
      A[0] += scale * sum11;
      A[1] += scale * sum12;
      A[i+1] += scale * sum21;
      A[i+2] += scale * sum22;
      A += 2;
    }
    // Diagonal 2x2 block, rows/cols i,i+1. A now points at A(i,i). Only the
    // lower triangle is needed: A(i,i), A(i+1,i) at A[i+1], and A(i+1,i+1)
    // at A[i+2]. B(i)*C(i+1) (upper triangle) is not computed.
    double sum11 = 0;
    double sum21 = 0;
    double sum22 = 0;
    const double *b = B + i*q;
    const double *c = C + i*q;
    int k;
    for (k = q; k > 3; k -= 4) {
      double loadb0 = b[0];
      double loadc0 = c[0];
      sum11 += loadb0 * loadc0;
      double loadb1 = b[q];
      sum21 += loadb1 * loadc0;
      double loadc1 = c[q];
      sum22 += loadb1 * loadc1;
      loadb0 = b[1];
      loadc0 = c[1];
      sum11 += loadb0 * loadc0;
      loadb1 = b[q+1];
      sum21 += loadb1 * loadc0;
      loadc1 = c[q+1];
      sum22 += loadb1 * loadc1;
      loadb0 = b[2];
      loadc0 = c[2];
      sum11 += loadb0 * loadc0;
      loadb1 = b[q+2];
      sum21 += loadb1 * loadc0;
      loadc1 = c[q+2];
      sum22 += loadb1 * loadc1;
      loadb0 = b[3];
      loadc0 = c[3];
      sum11 += loadb0 * loadc0;
      loadb1 = b[q+3];
      sum21 += loadb1 * loadc0;
      loadc1 = c[q+3];
      sum22 += loadb1 * loadc1;
      b += 4;
      c += 4;
    }
    for (; k > 0; k--) {
      double loadb0 = b[0];
      double loadb1 = b[q];
      double loadc0 = c[0];
      double loadc1 = c[q];
      sum11 += loadb0 * loadc0;
      sum21 += loadb1 * loadc0;
      sum22 += loadb1 * loadc1;
      b++;
      c++;
    }
    A[0] += scale * sum11;
    A[i+1] += scale * sum21;
    A[i+2] += scale * sum22;

    // From A(i,i), skip the rest of row i and all of row i+1 (1 + (i+2)
    // entries) to reach A(i+2,0).
    A += i+3;
  }
  // Odd p: last row i = p-1. A points at A(i,0).
  if (i < p) {
    int j;
    // Row i of B against rows j,j+1 of C, for j < i.
    for (j = 0; j < i; j += 2) {
      double sum11 = 0;
      double sum12 = 0;
      const double *b = B + i*q;
      const double *c = C + j*q;
      int k;
      for (k = q; k > 3; k -= 4) {
        double loadb0 = b[0];
        sum11 += loadb0 * c[0];
        sum12 += loadb0 * c[q];
        loadb0 = b[1];
        sum11 += loadb0 * c[1];
        sum12 += loadb0 * c[q+1];
        loadb0 = b[2];
        sum11 += loadb0 * c[2];
        sum12 += loadb0 * c[q+2];
        loadb0 = b[3];
        sum11 += loadb0 * c[3];
        sum12 += loadb0 * c[q+3];
        b += 4;
        c += 4;
      }
      for (; k > 0; k--) {
        double loadb0 = b[0];
        sum11 += loadb0 * c[0];
        sum12 += loadb0 * c[q];
        b++;
        c++;
      }
      A[0] += scale * sum11;
      A[1] += scale * sum12;
      A += 2;
    }
    // Diagonal element A(i,i).
    double sum11 = 0;
    const double *b = B + i*q;
    const double *c = C + i*q;
    int k;
    for (k = q; k > 3; k -= 4) {
      sum11 += b[0] * c[0];
      sum11 += b[1] * c[1];
      sum11 += b[2] * c[2];
      sum11 += b[3] * c[3];
      b += 4;
      c += 4;
    }
    for (; k > 0; k --) {
      sum11 += b[0] * c[0];
      b++;
      c++;
    }
    A[0] += scale * sum11;
  }
}

//***************************************************************************
// Reference functions, used for testing the optimized functions.

// Reference forward substitution: solve L*X = B in place.
//
// Layout:
//   * L is an n x n lower-triangular matrix packed row by row (L(i,j) at
//     L[i*(i+1)/2 + j]).
//   * B holds m right-hand sides of length n back to back (right-hand side
//     r starts at B + r*n).
// Each right-hand side is replaced by its solution.
void REF_SOLVE(const double *L, double *B, int n, int m) {
  CHECK(L != B);
  for (int rhs = 0; rhs < m; ++rhs) {
    double *x = B + rhs * n;
    const double *row = L;  // start of packed row i of L
    for (int i = 0; i < n; ++i) {
      double acc = 0;
      for (int j = 0; j < i; ++j) acc += row[j] * x[j];
      x[i] = (x[i] - acc) / row[i];
      row += i + 1;
    }
  }
}


// Reference version: A += scale * B * C'.
//
// A is p x r, B is p x q and C is r x q, all dense and stored row by row.
// Each A(i,j) is incremented by scale * dot(row i of B, row j of C). If any
// dimension is zero this is a no-op.
void REF_MULTIPLY2(double *A, double scale, const double *B, const double *C,
                   int p, int q, int r) {
  if (p == 0 || q == 0 || r == 0) return;
  CHECK(A != B && A != C);
  for (int row = 0; row < p; ++row) {
    const double *brow = B + row * q;
    double *arow = A + row * r;
    for (int col = 0; col < r; ++col) {
      const double *crow = C + col * q;
      double dot = 0;
      for (int k = 0; k < q; ++k) dot += brow[k] * crow[k];
      arow[col] += scale * dot;
    }
  }
}


// Reference version: A += scale * (lower triangle of B * C').
//
// B and C are p x q, dense and stored row by row. A is p x p, stored as a
// packed lower triangle row by row. Only entries (i,j) with j <= i are
// updated. If p or q is zero this is a no-op.
void REF_MULTIPLY2_LT(double *A, double scale, const double *B, const double *C,
                      int p, int q) {
  if (p == 0 || q == 0) return;
  CHECK(A != B && A != C);
  double *out = A;  // walks the packed triangle in storage order
  for (int row = 0; row < p; ++row) {
    const double *brow = B + row * q;
    for (int col = 0; col <= row; ++col, ++out) {
      const double *crow = C + col * q;
      double dot = 0;
      for (int k = 0; k < q; ++k) dot += brow[k] * crow[k];
      *out += scale * dot;
    }
  }
}

}; // namespace GPO
